{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-Tune a BERT Model and Create a Text Classifier\n",
    "\n",
    "We have already performed feature engineering to create BERT embeddings from the `review_body` text using the pre-trained BERT model, and split the dataset into train, validation, and test files. To optimize for TensorFlow training, we saved the files in TFRecord format.\n",
    "\n",
    "Now, let’s fine-tune the BERT model on our Customer Reviews Dataset and add a new classification layer to predict the `star_rating` for a given `review_body`.\n",
    "\n",
    "![BERT Training](img/bert_training.png)\n",
    "\n",
    "As mentioned earlier, BERT is built on the Transformer architecture, which relies on an attention mechanism. This is, not coincidentally, the name of the popular BERT Python library, “Transformers,” maintained by a company called [HuggingFace](https://github.com/huggingface/transformers). We will use a variant of BERT called [DistilBERT](https://arxiv.org/pdf/1910.01108.pdf), which requires less memory and compute but maintains very good accuracy on our dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DEMO 3: Run Model Training on Amazon SageMaker\n",
    "\n",
    "Amazon SageMaker is a fully managed service that provides every developer and data scientist with the ability to build, train, and deploy machine learning (ML) models quickly. SageMaker removes the heavy lifting from each step of the machine learning process to make it easier to develop high quality models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q sagemaker==2.16.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "\n",
    "session = sagemaker.Session()\n",
    "bucket = session.default_bucket()\n",
    "role = sagemaker.get_execution_role()\n",
    "region = boto3.Session().region_name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Specify the Dataset in S3\n",
    "We are using the train, validation, and test splits created in the previous section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "processed_train_data_s3_uri = \"s3://fsx-container-demo/input/data/train\"\n",
    "\n",
    "!aws s3 ls $processed_train_data_s3_uri/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "processed_validation_data_s3_uri = \"s3://fsx-container-demo/input/data/validation\"\n",
    "\n",
    "!aws s3 ls $processed_validation_data_s3_uri/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "processed_test_data_s3_uri = \"s3://fsx-container-demo/input/data/test\"\n",
    "\n",
    "!aws s3 ls $processed_test_data_s3_uri/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Specify S3 `Distribution Strategy`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from sagemaker.inputs import TrainingInput\n",
    "\n",
    "s3_input_train_data = TrainingInput(s3_data=processed_train_data_s3_uri, distribution=\"ShardedByS3Key\")\n",
    "s3_input_validation_data = TrainingInput(s3_data=processed_validation_data_s3_uri, distribution=\"ShardedByS3Key\")\n",
    "s3_input_test_data = TrainingInput(s3_data=processed_test_data_s3_uri, distribution=\"ShardedByS3Key\")\n",
    "\n",
    "print(s3_input_train_data.config)\n",
    "print(s3_input_validation_data.config)\n",
    "print(s3_input_test_data.config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set Up Hyperparameters for the Classification Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 3\n",
    "learning_rate = 0.00001\n",
    "epsilon = 0.00000001\n",
    "train_batch_size = 128\n",
    "validation_batch_size = 64\n",
    "test_batch_size = 64\n",
    "train_steps_per_epoch = 100\n",
    "validation_steps = 10\n",
    "test_steps = 10\n",
    "use_xla = True\n",
    "use_amp = True\n",
    "max_seq_length = 64\n",
    "freeze_bert_layer = True\n",
    "run_validation = True\n",
    "run_test = True\n",
    "run_sample_predictions = True"
   ]
  },
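  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for how much data these settings cover, the arithmetic can be sketched as follows. This is a minimal sketch that assumes each step consumes exactly one batch; whether `train.py` behaves this way is an assumption, and `examples_seen` is a hypothetical helper that is not part of the training script. The values are repeated from the cell above so the snippet stands on its own.\n",
    "\n",
    "```python\n",
    "# Values repeated from the hyperparameter cell above.\n",
    "epochs = 3\n",
    "train_batch_size = 128\n",
    "validation_batch_size = 64\n",
    "train_steps_per_epoch = 100\n",
    "validation_steps = 10\n",
    "\n",
    "\n",
    "def examples_seen(steps, batch_size, epochs=1):\n",
    "    # Hypothetical helper: assumes one step consumes exactly one batch.\n",
    "    return steps * batch_size * epochs\n",
    "\n",
    "\n",
    "train_examples_per_epoch = examples_seen(train_steps_per_epoch, train_batch_size)\n",
    "train_examples_total = examples_seen(train_steps_per_epoch, train_batch_size, epochs)\n",
    "validation_examples = examples_seen(validation_steps, validation_batch_size)\n",
    "\n",
    "print(train_examples_per_epoch)  # 12800\n",
    "print(train_examples_total)      # 38400\n",
    "print(validation_examples)       # 640\n",
    "```\n",
    "\n",
    "Under that assumption, the step counts cap how much of each split is read per epoch, regardless of how many records the TFRecord files contain."
   ]
  },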
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set Up Metrics to Track Model Performance\n",
    "\n",
    "These sample log lines...\n",
    "```\n",
    "45/50 [=====>..] - ETA: 3s - loss: 0.425 - accuracy: 0.881\n",
    "50/50 [=======>] - ETA: 0s - val_loss: 0.407 - val_accuracy: 0.885\n",
    "```\n",
    "...will produce the following four metrics in CloudWatch, using the names given in the `metrics_definitions` code cell:\n",
    "\n",
    "`train:loss` = 0.425\n",
    "\n",
    "`train:accuracy` = 0.881\n",
    "\n",
    "`validation:loss` = 0.407\n",
    "\n",
    "`validation:accuracy` = 0.885"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/cloudwatch_train_accuracy.png\" width=\"50%\" align=\"left\">\n",
    "\n",
    "<img src=\"img/cloudwatch_train_loss.png\" width=\"50%\" align=\"left\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics_definitions = [\n",
    "    {\"Name\": \"train:loss\", \"Regex\": \"loss: ([0-9\\\\.]+)\"},\n",
    "    {\"Name\": \"train:accuracy\", \"Regex\": \"accuracy: ([0-9\\\\.]+)\"},\n",
    "    {\"Name\": \"validation:loss\", \"Regex\": \"val_loss: ([0-9\\\\.]+)\"},\n",
    "    {\"Name\": \"validation:accuracy\", \"Regex\": \"val_accuracy: ([0-9\\\\.]+)\"},\n",
    "]"
   ]
  },
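  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before launching a job, it can help to try the regular expressions locally. The following is a minimal sketch that uses Python's built-in `re` module on the two sample log lines; it does not reproduce how SageMaker itself scans the log stream, and `extract_metrics` is a hypothetical helper. The patterns are repeated so the snippet stands on its own.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "# Same patterns as metrics_definitions, repeated so this snippet runs on its own.\n",
    "patterns = {\n",
    "    'train:loss': r'loss: ([0-9\\.]+)',\n",
    "    'train:accuracy': r'accuracy: ([0-9\\.]+)',\n",
    "    'validation:loss': r'val_loss: ([0-9\\.]+)',\n",
    "    'validation:accuracy': r'val_accuracy: ([0-9\\.]+)',\n",
    "}\n",
    "\n",
    "\n",
    "def extract_metrics(line):\n",
    "    # Hypothetical helper: first match of each pattern in a single log line.\n",
    "    found = {}\n",
    "    for name, pattern in patterns.items():\n",
    "        match = re.search(pattern, line)\n",
    "        if match:\n",
    "            found[name] = float(match.group(1))\n",
    "    return found\n",
    "\n",
    "\n",
    "train_line = '45/50 [=====>..] - ETA: 3s - loss: 0.425 - accuracy: 0.881'\n",
    "val_line = '50/50 [=======>] - ETA: 0s - val_loss: 0.407 - val_accuracy: 0.885'\n",
    "\n",
    "print(extract_metrics(train_line))\n",
    "print(extract_metrics(val_line))\n",
    "```\n",
    "\n",
    "With `re.search`, the unanchored `loss: ` and `accuracy: ` patterns also match inside `val_loss: ` and `val_accuracy: `, so the second line yields values for all four names. If that overlap matters for your dashboards, one option is to make the training patterns more specific."
   ]
  },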
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set Up Our BERT + TensorFlow Script to Run on SageMaker\n",
    "Prepare our TensorFlow model to run on the managed SageMaker service."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!pygmentize ./code/train.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.tensorflow import TensorFlow\n",
    "\n",
    "estimator = TensorFlow(\n",
    "    entry_point=\"train.py\",\n",
    "    source_dir=\"code\",\n",
    "    role=role,\n",
    "    instance_count=1,\n",
    "    instance_type=\"ml.c5.9xlarge\",\n",
    "    use_spot_instances=True,\n",
    "    max_run=3600,\n",
    "    max_wait=3600,\n",
    "    volume_size=1024,\n",
    "    py_version=\"py3\",\n",
    "    framework_version=\"2.1.0\",\n",
    "    hyperparameters={\n",
    "        \"epochs\": epochs,\n",
    "        \"learning_rate\": learning_rate,\n",
    "        \"epsilon\": epsilon,\n",
    "        \"train_batch_size\": train_batch_size,\n",
    "        \"validation_batch_size\": validation_batch_size,\n",
    "        \"test_batch_size\": test_batch_size,\n",
    "        \"train_steps_per_epoch\": train_steps_per_epoch,\n",
    "        \"validation_steps\": validation_steps,\n",
    "        \"test_steps\": test_steps,\n",
    "        \"use_xla\": use_xla,\n",
    "        \"use_amp\": use_amp,\n",
    "        \"max_seq_length\": max_seq_length,\n",
    "        \"freeze_bert_layer\": freeze_bert_layer,\n",
    "        \"run_validation\": run_validation,\n",
    "        \"run_test\": run_test,\n",
    "        \"run_sample_predictions\": run_sample_predictions,\n",
    "    },\n",
    "    input_mode=\"Pipe\",\n",
    "    metric_definitions=metrics_definitions,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the Model on SageMaker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator.fit(\n",
    "    inputs={\"train\": s3_input_train_data, \"validation\": s3_input_validation_data, \"test\": s3_input_test_data},\n",
    "    wait=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_job_name = estimator.latest_training_job.name\n",
    "print(\"Training Job Name:  {}\".format(training_job_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "\n",
    "display(\n",
    "    HTML(\n",
    "        '<b>Review <a target=\"_blank\" href=\"https://console.aws.amazon.com/sagemaker/home?region={}#/jobs/{}\">Training Job</a> After About 5 Minutes</b>'.format(\n",
    "            region, training_job_name\n",
    "        )\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "\n",
    "display(\n",
    "    HTML(\n",
    "        '<b>Review <a target=\"_blank\" href=\"https://console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TrainingJobs;prefix={};streamFilter=typeLogStreamPrefix\">CloudWatch Logs</a> After About 5 Minutes</b>'.format(\n",
    "            region, training_job_name\n",
    "        )\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "\n",
    "display(\n",
    "    HTML(\n",
    "        '<b>Review <a target=\"_blank\" href=\"https://s3.console.aws.amazon.com/s3/buckets/{}/{}/?region={}&tab=overview\">S3 Output Data</a> After The Training Job Has Completed</b>'.format(\n",
    "            bucket, training_job_name, region\n",
    "        )\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "estimator.latest_training_job.wait(logs=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Wait Until the ^^ Training Job ^^ Completes Above!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 ls s3://$bucket/$training_job_name/output/"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
